#!/usr/bin/env python
# coding: utf-8

# ## Mount drive and set current directory

# In[1]:


import os
# Move the working directory one level up; relative paths used further down
# (e.g. "data/output/") are resolved against it. Presumably the parent is the
# project root -- confirm for your checkout layout.
os.chdir("..")
# Echo the resulting working directory so the notebook run shows where it landed.
print(os.getcwd())


# ## Random Pick

# In[2]:


from random import choices

def random_indices(max_obs, num_obs):
  """Draw `num_obs` indices uniformly at random from ``range(max_obs)``.

  Sampling is done with replacement, so the returned list may contain the
  same index more than once.

  Args:
    max_obs: exclusive upper bound of the index range.
    num_obs: how many indices to draw.

  Returns:
    A list of `num_obs` ints in ``[0, max_obs)``.
  """
  population = range(max_obs)
  return choices(population, k=num_obs)


# ## K-means Clustering

# In[3]:


from fast_pytorch_kmeans import KMeans
import torch
from random import choices

def kmeans_indices(obs, num_obs):
  """Pick `num_obs` distinct observation indices spread over k-means clusters.

  The observations are clustered into `num_obs` clusters (euclidean mode) and
  one random member is drawn from every non-empty cluster. If some clusters
  came back empty, the shortfall is filled by drawing further, not yet chosen
  members, visiting the clusters round-robin so the extras stay spread out.

  Args:
    obs: observations passed straight to ``KMeans.fit_predict``; presumably a
      2-D torch tensor of shape (n_samples, n_features) -- confirm with caller.
    num_obs: number of clusters to fit and number of indices requested.

  Returns:
    A list of unique int indices into `obs`. It has `num_obs` entries unless
    fewer observations than that were assigned to clusters.
  """
  kmeans = KMeans(n_clusters=num_obs, mode='euclidean', verbose=1)
  labels = kmeans.fit_predict(obs)

  # Group observation indices by the cluster label they were assigned.
  members_by_label = {}
  for index, label in enumerate(labels.tolist()):
    members_by_label.setdefault(label, []).append(index)

  # One representative per non-empty cluster.
  indices = [choices(members, k=1)[0] for members in members_by_label.values()]

  if len(indices) < num_obs:
    # Empty clusters left us short: top up from members that were not picked
    # yet, so the result never contains duplicates and actually reaches
    # num_obs whenever enough observations exist.
    chosen = set(indices)
    pools = [[i for i in members if i not in chosen]
             for members in members_by_label.values()]
    pools = [pool for pool in pools if pool]
    while pools and len(indices) < num_obs:
      for pool in pools:
        pick = choices(pool, k=1)[0]
        pool.remove(pick)
        indices.append(pick)
        if len(indices) == num_obs:
          break
      # Drop exhausted clusters before the next round.
      pools = [pool for pool in pools if pool]
  return indices


# ## Farthest Point Sampler

# In[5]:


import numpy as np

def farthestPointSampler(dist_matrix, num_obs):
  """Greedy farthest-point (max-min) selection on a pairwise distance matrix.

  The first two picks are the row/column of the largest entry in
  `dist_matrix`. Every further pick is the not-yet-selected point whose
  minimum distance to the already selected points is largest.

  Args:
    dist_matrix: 2-D array of pairwise distances. Assumed to be square with
      rows and columns in the same order -- TODO confirm with callers.
    num_obs: number of indices to return.

  Returns:
    np.ndarray (int64) holding `num_obs` indices in selection order.

  Raises:
    IndexError: if `num_obs` exceeds the number of rows in `dist_matrix`.
  """
  dist_matrix = np.asarray(dist_matrix)
  indices = np.zeros(num_obs, dtype=np.int64)
  if num_obs == 0:
    return indices

  num_points = dist_matrix.shape[0]
  if num_obs > num_points:
    # IndexError (rather than ValueError) keeps the failure type callers
    # would see when the candidate pool runs dry.
    raise IndexError(
        "num_obs (%d) exceeds the number of points (%d)" % (num_obs, num_points))

  # Seed with the two farthest points.
  first, second = np.unravel_index(dist_matrix.argmax(), dist_matrix.shape)
  indices[0] = first
  if num_obs == 1:
    return indices
  indices[1] = second

  selected = np.zeros(dist_matrix.shape[1], dtype=bool)
  selected[[first, second]] = True
  # Running minimum distance from every point to the selected set; updating
  # it per pick avoids re-reducing over all selected rows each iteration.
  min_dist = np.minimum(dist_matrix[first], dist_matrix[second])

  for i in range(2, num_obs):
    # Already selected points must never win the argmax.
    candidates = np.where(selected, -np.inf, min_dist)
    # Ties go to the highest index (same as taking the head of a reversed
    # stable ascending sort).
    pick = candidates.size - 1 - np.argmax(candidates[::-1])
    indices[i] = pick
    selected[pick] = True
    min_dist = np.minimum(min_dist, dist_matrix[pick])
  return indices


# ## Greedy farthest point based on KL Divergence

# In[7]:


# Basis: normal distribution in all embeddings


# In[8]:


import pandas as pd


# In[11]:


from scipy.stats import norm
from itertools import combinations_with_replacement
from itertools import chain
import numpy as np

def fit_norm(obs):
  """Fit a normal distribution to every row of `obs`.

  Each row ``obs[i, :]`` is passed to ``scipy.stats.norm.fit``.

  Args:
    obs: 2-D array-like supporting ``obs.shape`` and ``obs[i, :]`` indexing;
      one row per observation.

  Returns:
    Tuple ``(mu_list, sd_list)`` of two lists with one fitted mean and one
    fitted standard deviation per row, in row order.
  """
  fitted = [norm.fit(obs[row, :]) for row in range(obs.shape[0])]
  mu_list = [mu for mu, _ in fitted]
  sd_list = [sd for _, sd in fitted]
  return mu_list, sd_list

def gaussian_kld(mu1, sd1, mu2, sd2):
  """KL divergence KL(N(mu1, sd1^2) || N(mu2, sd2^2)) of two univariate normals.

  Evaluates ``log(sd2/sd1) + (sd1^2 + (mu1-mu2)^2) / (2*sd2^2) - 1/2``.
  Works on scalars as well as on NumPy arrays (element-wise).
  """
  log_scale_ratio = np.log(sd2/sd1)
  numerator = sd1**2 + (mu1-mu2)**2
  denominator = 2*(sd2**2)
  return log_scale_ratio + (numerator / denominator) - 0.5

def get_kld_matrix(mu_list, sd_list, dataset_name="", embed_type=""):
  """Build the symmetrised Gaussian KL matrix and save it under data/output/.

  Entry (i, j) is ``gaussian_kld(i -> j) + gaussian_kld(j -> i)`` for the
  normal distributions described by ``mu_list[k]`` / ``sd_list[k]``.

  Args:
    mu_list: per-observation means.
    sd_list: per-observation standard deviations, aligned with `mu_list`.
    dataset_name: prefix of the output file name.
    embed_type: suffix of the output file name.

  Returns:
    The file name ``<dataset_name>_kld_<embed_type>.npy`` (without the
    "data/output/" directory it was written to).
  """
  size = len(mu_list)
  kld_matrix = np.zeros((size, size))
  # Only the upper triangle (diagonal included) is computed; the value is
  # mirrored because the symmetrised divergence does not depend on order.
  for i in range(size):
    for j in range(i, size):
      forward = gaussian_kld(mu_list[i], sd_list[i], mu_list[j], sd_list[j])
      backward = gaussian_kld(mu_list[j], sd_list[j], mu_list[i], sd_list[i])
      kld_matrix[i, j] = kld_matrix[j, i] = forward + backward
  print('Saving kld matrix...')
  stem = dataset_name+'_kld_'+embed_type
  np.save("data/output/" +stem, kld_matrix)
  return stem+'.npy'


# ## Greedy Farthest Point Sampler using Kolmogorov-Smirnov measure

# In[13]:


def ks_2samp_faster(data1, data2):
    """Two-sample Kolmogorov-Smirnov statistic for two pre-sorted samples.

    Both inputs must be 1-D NumPy arrays sorted in ascending order, because
    the empirical CDFs are evaluated with ``np.searchsorted``.

    Returns:
        The largest absolute gap between the two empirical CDFs, evaluated
        at every value of the pooled sample.
    """
    pooled = np.concatenate([data1, data2])
    size1 = data1.shape[0]
    size2 = data2.shape[0]
    # side='right' counts values <= x, which handles ties between samples.
    ecdf_gap = (np.searchsorted(data1, pooled, side='right') / size1
                - np.searchsorted(data2, pooled, side='right') / size2)
    # Largest gap where sample 2 leads, clipped so it is never negative.
    gap_below = np.clip(-np.min(ecdf_gap), 0, 1)
    # Largest gap where sample 1 leads.
    gap_above = np.max(ecdf_gap)
    return max(gap_below, gap_above)


# In[14]:


def get_ks_matrix(obs, dataset_name="", embed_type=""):
  """Build the pairwise Kolmogorov-Smirnov matrix and save it under data/output/.

  Every row of `obs` is treated as one sample. Rows are sorted once up front
  because ``ks_2samp_faster`` expects sorted input.

  Args:
    obs: 2-D array-like, one sample per row.
    dataset_name: prefix of the output file name.
    embed_type: suffix of the output file name.

  Returns:
    The file name ``<dataset_name>_ks_<embed_type>.npy`` (without the
    "data/output/" directory it was written to).
  """
  num_obs = len(obs)
  ks_matrix = np.zeros((num_obs, num_obs))
  sorted_rows = np.sort(obs, axis=1)
  # Fill the upper triangle (diagonal included) and mirror each value.
  for i, sample_i in enumerate(sorted_rows):
    for j in range(i, num_obs):
      ks_matrix[i, j] = ks_matrix[j, i] = ks_2samp_faster(sample_i, sorted_rows[j])
  print('Saving ks matrix...')
  stem = dataset_name+'_ks_'+embed_type
  np.save("data/output/" +stem, ks_matrix)
  return stem+'.npy'


# ## Greedy Farthest Point Sampler using Cosine Distance Matrix
# 

# In[15]:


from scipy.spatial.distance import cosine
def get_cos_matrix(obs, dataset_name="", embed_type=""):
  """Build the pairwise cosine-distance matrix and save it under data/output/.

  Distances come from ``scipy.spatial.distance.cosine`` applied to every
  pair of entries ``obs[i]``, ``obs[j]``.

  Args:
    obs: sequence of 1-D vectors (e.g. a 2-D array, one vector per row).
    dataset_name: prefix of the output file name.
    embed_type: suffix of the output file name.

  Returns:
    The file name ``<dataset_name>_cos_<embed_type>.npy`` (without the
    "data/output/" directory it was written to).
  """
  num_obs = len(obs)
  cos_matrix = np.zeros((num_obs, num_obs))
  # Compute the upper triangle (diagonal included) and mirror each value.
  for i in range(num_obs):
    vec_i = obs[i]
    for j in range(i, num_obs):
      cos_matrix[i, j] = cos_matrix[j, i] = cosine(vec_i, obs[j])
  print('Saving cosine matrix...')
  stem = dataset_name+'_cos_'+embed_type
  np.save("data/output/" +stem, cos_matrix)
  return stem+'.npy'


# ## D-Optimality

# In[17]:


import numpy as np

def variancedist(W, S):
    """Score every row w of `W` by the quadratic form ``w @ inv(S.T @ S) @ w``.

    Args:
        W: 2-D array (NumPy array or CPU torch tensor) of candidate rows.
        S: 2-D array (NumPy array or CPU torch tensor) of already selected
            rows, with the same number of columns as `W`.

    Returns:
        1-D ``torch.Tensor`` of dtype float32 with one score per row of `W`.
    """
    # Work in NumPy throughout: the inverse comes from np.linalg, and torch
    # reductions reject plain ndarrays.
    candidates = np.asarray(W)
    selected = np.asarray(S)
    gram = selected.T @ selected
    try:
        R = np.linalg.inv(gram)
    except np.linalg.LinAlgError:
        # Singular Gram matrix (e.g. fewer selected rows than columns):
        # add a small ridge term so the inverse exists.
        R = np.linalg.inv(gram + np.eye(selected.shape[1]) * 0.001)
    # Row-wise w @ R @ w without forming the full (n x n) product.
    scores = np.sum((candidates @ R) * candidates, axis=1)
    return torch.as_tensor(scores, dtype=torch.float)

def dopt(topicmat, k):
    """Greedily select rows of `topicmat` by maximising ``variancedist``.

    Starts from one row drawn with ``np.random.choice`` and keeps adding the
    remaining row with the highest ``variancedist`` score against the rows
    selected so far, until `k` rows are selected.

    Args:
        topicmat: 2-D matrix, one candidate per row.
        k: number of rows to select (at least one row is always returned).

    Returns:
        List of selected row indices into `topicmat`, in selection order.
    """
    total_rows = topicmat.shape[0]
    index = [np.random.choice(total_rows)]
    seed_mask = np.isin(range(total_rows), index)
    S = topicmat[index,:]
    W = topicmat[~seed_mask]
    # rows[p] is the original row number of the candidate stored at W[p].
    rows = np.delete(np.array(range(total_rows)), index)
    while len(index) < k:
        best = np.argmax(variancedist(W, S))
        # Move the winning candidate from the pool into the selected set.
        S = np.vstack((S, W[best,:]))
        W = np.delete(W, best, axis=0)
        index.append(rows[best])
        rows = np.delete(rows, best)
    return index

